from functools import partial
from typing import Optional

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass

from einops import rearrange

# constants

# Backend switches for scaled-dot-product attention. An instance is expanded
# with `._asdict()` into keyword arguments of `torch.backends.cuda.sdp_kernel`
# inside `Attend.flash_attn`, so the field names must match that API.
_efficient_attention_fields = (
    "enable_flash",
    "enable_math",
    "enable_mem_efficient",
)
EfficientAttentionConfig = namedtuple(
    "EfficientAttentionConfig", _efficient_attention_fields
)


@dataclass
class Intermediates:
    """Optional attention internals returned next to the attention output.

    Every field defaults to None; the flash path in `Attend` returns an
    instance with nothing filled in.
    """

    # scaled q·k logits (plus any residual prev_attn), before talking heads / bias / masks
    qk_similarities: Optional[Tensor] = None
    # logits after bias and masking, immediately before the softmax
    pre_softmax_attn: Optional[Tensor] = None
    # attention weights after the softmax, before dropout
    post_softmax_attn: Optional[Tensor] = None

    def to_tuple(self):
        """Return ``(qk_similarities, pre_softmax_attn, post_softmax_attn)``."""
        ordered = (
            self.qk_similarities,
            self.pre_softmax_attn,
            self.post_softmax_attn,
        )
        return ordered


# helpers


def exists(val):
    """Return True when ``val`` is anything other than None (falsy values count)."""
    return val is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``."""
    if exists(val):
        return val
    return d


def compact(arr):
    """Return a new list holding the non-None items of ``arr``, order preserved."""
    return [item for item in arr if exists(item)]


def once(fn):
    """Wrap the single-argument callable ``fn`` so only its first call runs.

    The first call returns ``fn(x)``; every later call is a no-op that
    returns None.
    """
    already_ran = False

    @wraps(fn)
    def inner(x):
        nonlocal already_ran
        if not already_ran:
            already_ran = True
            return fn(x)
        return None

    return inner


# `print` that only emits the first message it is ever given (shared across
# every caller in this module), e.g. for one-off hardware notices.
print_once = once(fn=print)

# functions for creating the causal mask
# ONNX export on CPU does not support `.triu`, so a separate, comparison-based version is needed


def create_causal_mask(i, j, device):
    """Build an (i, j) boolean mask that is True where a query must NOT attend.

    Query row ``r`` is aligned with the last ``i`` of the ``j`` key positions,
    so entry ``(r, c)`` is True exactly when ``c > r + (j - i)``.
    """
    all_true = torch.ones((i, j), device=device, dtype=torch.bool)
    return torch.triu(all_true, diagonal=j - i + 1)


def onnx_create_causal_mask(i, j, device):
    """Same mask as ``create_causal_mask`` but built without ``.triu``.

    Compares query / key position indices to get a strictly upper-triangular
    (i, i) mask, then pads ``j - i`` all-False columns on the left (a negative
    pad amount crops instead).
    """
    positions = torch.arange(i, device=device)
    square = positions.unsqueeze(-1) < positions.unsqueeze(0)
    # the leading j - i key columns are never masked for any query
    return F.pad(square, (j - i, 0), value=False)


# main class


class Attend(nn.Module):
    """Core attention: softmax(q k^T * scale + bias, masked) @ v.

    Args:
        dropout: dropout probability applied to the attention weights.
        causal: if True, queries cannot attend to later key positions
            (mask built by ``create_causal_mask`` / ``onnx_create_causal_mask``).
        heads: number of heads; only used to size the talking-heads 1x1 convs.
        talking_heads: mix logits / weights across the head axis with 1x1
            convolutions before and after the softmax.
        sparse_topk: if set, each query row keeps only its ``sparse_topk``
            largest logits; the rest are masked out.
        scale: multiplier for ``q @ k^T``; defaults to ``dim_head ** -0.5``.
        qk_norm: presumably set when q / k are already normalized by the caller
            (cosine-sim attention); the softmax then runs in the logits' dtype
            instead of float32.
        flash: use ``F.scaled_dot_product_attention`` (torch >= 2.0; not
            compatible with talking heads, sparse top-k or residual attention).
        add_zero_kv: prepend a key / value position filled with zeros.
        onnxable: build causal masks without ``.triu`` for ONNX export.
    """

    def __init__(
        self,
        *,
        dropout=0.0,
        causal=False,
        heads=None,
        talking_heads=False,
        sparse_topk=None,
        scale=None,
        qk_norm=False,
        flash=False,
        add_zero_kv=False,
        onnxable=False,
    ):
        super().__init__()
        self.scale = scale
        self.qk_norm = qk_norm

        self.causal = causal
        self.create_causal_mask = (
            onnx_create_causal_mask
            if onnxable
            else create_causal_mask
        )

        # without qk_norm the softmax is computed in float32 and cast back in
        # forward(); with qk_norm it runs in the logits' own dtype
        self.attn_fn = (
            partial(F.softmax, dtype=torch.float32)
            if not qk_norm
            else F.softmax
        )

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads: 1x1 convs over the head (channel) axis of (b, h, i, j)

        assert not (
            flash and talking_heads
        ), "talking heads not compatible with flash attention"

        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_talking_heads = nn.Conv2d(
                heads, heads, 1, bias=False
            )
            self.post_softmax_talking_heads = nn.Conv2d(
                heads, heads, 1, bias=False
            )

        # sparse topk

        assert not (
            flash and sparse_topk
        ), "sparse topk not compatible with flash attention"
        self.sparse_topk = sparse_topk

        # add a key / value token composed of zeros
        # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

        self.add_zero_kv = add_zero_kv

        # flash attention

        self.flash = flash
        assert not (
            flash
            and version.parse(torch.__version__)
            < version.parse("2.0.0")
        ), (
            "in order to use flash attention, you must be using"
            " pytorch 2.0 or above"
        )

        # determine efficient attention configs for cuda and cpu
        # (cuda_config stays None unless flash is requested and cuda exists)

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(
            torch.device("cuda")
        )

        # compute capability 8.0 -> flash kernel only; otherwise math / mem-efficient
        if (
            device_properties.major == 8
            and device_properties.minor == 0
        ):
            print_once(
                "A100 GPU detected, using flash attention if input"
                " tensor is on cuda"
            )
            self.cuda_config = EfficientAttentionConfig(
                True, False, False
            )
        else:
            print_once(
                "Non-A100 GPU detected, using math or mem efficient"
                " attention if input tensor is on cuda"
            )
            self.cuda_config = EfficientAttentionConfig(
                False, True, True
            )

    def flash_attn(self, q, k, v, mask=None, attn_bias=None):
        """Attention through ``F.scaled_dot_product_attention``.

        Args:
            q: queries, unpacked as (batch, heads, q_len, dim_head).
            k, v: keys / values, either with a head axis (4-D) or shared by
                all query heads (3-D, broadcast to q's shape).
            mask: optional rank-4 boolean mask, True = attend; expanded to
                (batch, heads, q_len, k_len).
            attn_bias: optional additive bias shaped (heads, q_len, k_len).

        Returns:
            ``(out, Intermediates())``; this path records no intermediates.
        """
        batch, heads, q_len, _, k_len, is_cuda, device = (
            *q.shape,
            k.shape[-2],
            q.is_cuda,
            q.device,
        )

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = rearrange(k, "b ... -> b 1 ...").expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, "b ... -> b 1 ...").expand_as(q)

        # No `scale` is passed to the SDPA call below, so it multiplies
        # q @ k^T by dim_head ** -0.5. To end up with self.scale instead,
        # pre-multiply q by self.scale / default_scale (the two factors then
        # combine to self.scale). This mirrors the non-flash path, which uses
        # self.scale whenever it is set (e.g. cosine-sim / qk_norm attention).

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # the mask must already be rank 4 (broadcastable to b h i j);
        # expand it to the full shape SDPA expects

        causal = self.causal

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

            # an explicit mask is given, so fold the causal mask into it
            # instead of also passing is_causal

            if causal:
                causal_mask = self.create_causal_mask(
                    q_len, k_len, device=device
                )
                mask = mask & ~causal_mask
                causal = False

        # handle alibi positional bias
        # the bias is float, so any boolean mask is folded into it as a large
        # negative value and the bias is then passed as the (additive) mask

        if exists(attn_bias):
            attn_bias = rearrange(
                attn_bias, "h i j -> 1 h i j"
            ).expand(batch, heads, -1, -1)

            # if mask given, the mask would already contain the causal mask from above logic
            # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(
                    ~mask, mask_value // 2
                )
            elif causal:
                causal_mask = self.create_causal_mask(
                    q_len, k_len, device=device
                )
                attn_bias = attn_bias.masked_fill(
                    causal_mask, mask_value // 2
                )
                causal = False

            # scaled_dot_product_attention handles attn_mask either as bool or additive bias
            # make it an additive bias here

            mask = attn_bias

        # pick the backend configuration for the tensor's device

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=causal,
            )

        return out, Intermediates()

    def forward(
        self, q, k, v, mask=None, attn_bias=None, prev_attn=None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        Args:
            q: queries (b, h, i, d).
            k, v: keys / values (b, h, j, d), or (b, j, d) shared by all heads.
            mask: optional boolean mask, True = attend, broadcastable to
                (b, h, i, j).
            attn_bias: optional additive bias added to the logits.
            prev_attn: optional residual logits added to the scaled q·k.

        Returns:
            ``(out, Intermediates)`` with out shaped (b, h, i, d).
        """

        device = q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.add_zero_kv:
            # prepend one all-zero key / value position; it is always attendable
            k, v = map(
                lambda t: F.pad(t, (0, 0, 1, 0), value=0.0), (k, v)
            )

            if exists(mask):
                mask = F.pad(mask, (1, 0), value=True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value=0.0)

        if self.flash:
            assert not exists(prev_attn), (
                "residual attention not compatible with flash"
                " attention"
            )
            return self.flash_attn(
                q, k, v, mask=mask, attn_bias=attn_bias
            )

        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"

        dots = (
            einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k)
            * scale
        )

        if exists(prev_attn):
            dots = dots + prev_attn

        qk_similarities = dots.clone()

        if self.talking_heads:
            dots = self.pre_softmax_talking_heads(dots)

        if exists(attn_bias):
            dots = dots + attn_bias

        i, j, dtype = *dots.shape[-2:], dots.dtype

        mask_value = -torch.finfo(dots.dtype).max

        # sparse top-k: masks here use "True = attend" (see ~mask below), so
        # the keep-mask must be True for logits >= the k-th largest in the row
        if exists(self.sparse_topk) and self.sparse_topk < j:
            top_values, _ = dots.topk(self.sparse_topk, dim=-1)
            sparse_topk_mask = dots >= top_values[..., -1:]
            mask = (
                (mask & sparse_topk_mask)
                if exists(mask)
                else sparse_topk_mask
            )

        if exists(mask):
            dots = dots.masked_fill(~mask, mask_value)

        if self.causal:
            # causal_mask is True where attention is forbidden
            causal_mask = self.create_causal_mask(i, j, device=device)
            dots = dots.masked_fill(causal_mask, mask_value)

        pre_softmax_attn = dots.clone()

        attn = self.attn_fn(dots, dim=-1)
        attn = attn.type(dtype)

        post_softmax_attn = attn.clone()

        attn = self.attn_dropout(attn)

        if self.talking_heads:
            attn = self.post_softmax_talking_heads(attn)

        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        intermediates = Intermediates(
            qk_similarities=qk_similarities,
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn,
        )

        return out, intermediates


# cascading heads logic


def to_single_heads(t, dim=1):
    """Split ``t`` along ``dim`` into a tuple of slices that keep ``dim`` as size 1.

    Concatenating the result along ``dim`` reproduces ``t``.
    """
    slices = torch.unbind(t, dim=dim)
    return tuple(torch.unsqueeze(s, dim) for s in slices)


class CascadingHeads(nn.Module):
    """Run ``attend`` one head at a time, adding each head's output to the next
    head's queries (cascaded heads).

    Args:
        attend: the ``Attend`` module applied to every single-head slice.
    """

    def __init__(self, attend: Attend):
        super().__init__()
        self.attend = attend

    def forward(
        self, q, k, v, mask=None, attn_bias=None, prev_attn=None
    ):
        """Cascade attention over the heads on dim 1 of ``q``.

        Args:
            q: queries with heads on dim 1.
            k, v: 4-D tensors with heads on dim 1, or 3-D tensors shared by
                every head.
            mask: optional mask. A 4-D mask with one slice per head on dim 1
                is split per head; any other mask is given to every head as-is.
            attn_bias: optional bias with heads on dim 0.
            prev_attn: optional residual attention logits with heads on dim 1.

        Returns:
            ``(out, Intermediates)``: per-head outputs concatenated on dim 1,
            and each available intermediate concatenated on dim 1 (None when
            no head produced it).
        """
        assert q.shape[-1] == v.shape[-1], (
            "cascading heads can only be done if query / key and"
            " value head dimensions are the same"
        )

        # split inputs into per-head inputs

        heads = q.shape[1]

        queries = to_single_heads(q)
        keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
        values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)

        # a per-head mask must be sliced like q / k / v: handing the full
        # (b, h, i, j) mask to a single-head call would not match that call's
        # single-head logits
        if exists(mask) and mask.ndim == 4 and mask.shape[1] == heads:
            masks = to_single_heads(mask)
        else:
            masks = (mask,) * heads

        attn_bias = (
            to_single_heads(attn_bias, dim=0)
            if exists(attn_bias)
            else ((None,) * heads)
        )
        prev_attn = (
            to_single_heads(prev_attn)
            if exists(prev_attn)
            else ((None,) * heads)
        )

        # now loop through each head, with the output of the previous head
        # summed into the next head's queries - thus cascading

        all_outs = []
        all_intermediates = []

        prev_head_out = None

        for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(
            queries, keys, values, masks, attn_bias, prev_attn
        ):
            if exists(prev_head_out):
                h_q = h_q + prev_head_out

            out, intermediates = self.attend(
                h_q,
                h_k,
                h_v,
                mask=h_mask,
                attn_bias=h_attn_bias,
                prev_attn=h_prev_attn,
            )

            prev_head_out = out

            all_outs.append(out)
            all_intermediates.append(intermediates)

        # cat all output heads

        all_outs = torch.cat(all_outs, dim=1)

        # cat all intermediates, if they exist

        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(
            *map(lambda i: i.to_tuple(), all_intermediates)
        )

        qk_similarities, pre_softmax_attn, post_softmax_attn = map(
            compact,
            (qk_similarities, pre_softmax_attn, post_softmax_attn),
        )

        aggregated_intermediates = Intermediates(
            qk_similarities=(
                torch.cat(qk_similarities, dim=1)
                if len(qk_similarities) > 0
                else None
            ),
            pre_softmax_attn=(
                torch.cat(pre_softmax_attn, dim=1)
                if len(pre_softmax_attn) > 0
                else None
            ),
            post_softmax_attn=(
                torch.cat(post_softmax_attn, dim=1)
                if len(post_softmax_attn) > 0
                else None
            ),
        )

        return all_outs, aggregated_intermediates
